#coding=utf-8

import numpy
from dcnn import *
import logging, time
from itertools import izip
import copy

# Douban training corpus: one segmented sentence per line, with one integer
# label per line in the parallel .label file.
train_data_file = '../data/douban/train2.data'
train_label_file = '../data/douban/train2.label'
train_sentences = LineSentence(train_data_file)
train_labels = numpy.fromfile(train_label_file, sep='\n', dtype=numpy.int32)

# Held-out Douban test split, same layout as the training files.
test_data_file = '../data/douban/test2.data'
test_label_file = '../data/douban/test2.label'
test_sentences = LineSentence(test_data_file)
test_labels = numpy.fromfile(test_label_file, sep='\n', dtype=numpy.int32)

# Large unlabelled corpus (segmented Sougou news) used only by test_pre_train.
# NOTE(review): absolute path — only valid on the original training machine.
pre_train_data_file = '/home/hadoop/corpus/sougou_news/seg'
pre_train_sentences = LineSentence(pre_train_data_file)


def test_train():
    """Train a DCNN classifier on the douban split and evaluate on the test split."""
    model = DCNN(sentences=train_sentences, workers=1, full_con_layer_size=20)
    model.train(train_sentences, train_labels, test_sentences, test_labels)


def test_dynamic_kmax_pooling():
    kmax_pooling = DynamicKMaxPooling(2, 5, 2, 1, 3)
    input = numpy.random.rand(100).reshape(2, 5, 10)
    print 'input:\n', input
    print 'sentence_len:\n', input.shape[2]
    output = kmax_pooling.forward(input, input.shape[2])
    print 'output:\n', output
    back_grad = kmax_pooling.backward(output)
    print 'back_grad:\n', back_grad


def test_folding():
    folding = Folding(2)
    input = numpy.random.rand(80).reshape(2, 4, 10)
    print 'input:\n', input
    output = folding.forward(input)
    print 'output:\n', output
    back_grad = folding.backward(output, back_linear=True)
    print 'back_grad:\n', back_grad


def test_grad():
    dcnn = DCNNDeep(wordvec_dim=4, sentences=train_sentences, n_filters=[2, 2], min_count=100, len_of_sentence_limit=20)
    epsilon = 1e-5
    for j, (sentence, y) in enumerate(izip(train_sentences, train_labels)):
        if j < 5:
            dcnn.forward(sentence)
            dcnn.backward(y)
            grads = dcnn.get_grads()
            params = dcnn.get_params()
            for i in xrange(len(params) - 1, 0, -1):
                #for i in xrange(len(params)):
                shape = params[i].shape
                print shape
                flatten_param = params[i].flatten()
                flatten_grad = grads[i].flatten()
                num_grad = numpy.empty(flatten_grad.shape, dtype=numpy.float32)
                for j in xrange(len(flatten_param)):
                    old_param = flatten_param[j]
                    flatten_param[j] = old_param + epsilon
                    params[i] = flatten_param.reshape(shape)
                    dcnn.set_params(params)
                    dcnn.forward(sentence)
                    pos = dcnn.get_loss(y)
                    flatten_param[j] = old_param - epsilon
                    params[i] = flatten_param.reshape(shape)
                    dcnn.set_params(params)
                    dcnn.forward(sentence, y)
                    neg = dcnn.get_loss(y)
                    num_grad[j] = (pos - neg) / (2 * epsilon)
                    flatten_param[j] = old_param
                    params[i] = flatten_param.reshape(shape)
                    dcnn.set_params(params)
                # print flatten_grad, '\n', num_grad
                diff = numpy.linalg.norm(flatten_grad - num_grad) / numpy.linalg.norm(flatten_grad + num_grad)
                print '%d/%d' % (i + 1, len(params)), diff
                print '=' * 10
            print '*' * 20
            dcnn.update()

def test_most_similar():
    dcnn = DCNN.load('../model/dcnn_model.pkl')
    test_s = "难题"
    for word, dist in dcnn.most_similar(test_s):
        print word, dist

def test_pre_train():
    """Pre-train a DCNN on the unlabelled Sougou corpus, then train/evaluate on douban."""
    model = DCNN(pre_train=True, pre_train_sentences=pre_train_sentences, sentences=train_sentences)
    model.train(train_sentences, train_labels, test_sentences, test_labels)

if __name__ == '__main__':
    # Timestamped INFO-level logging; must be configured before any training call.
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    # Only the gradient check runs by default; the other test_* helpers are
    # invoked manually by editing this line.
    test_grad()